import numpy as np
import scipy.signal

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal

from spinup.exploration.nets_exploration import (
    OffPolicyCoherentLinear,
    PSNELinear,
)


def combined_shape(length, shape=None):
    """Return a shape tuple with ``length`` prepended to ``shape``.

    Args:
        length: leading dimension (e.g. a buffer size).
        shape: ``None`` for no trailing dims, a scalar for a single trailing
            dim, or an iterable of trailing dims.

    Returns:
        ``(length,)``, ``(length, shape)`` or ``(length, *shape)`` respectively.
    """
    if shape is None:
        trailing = ()
    elif np.isscalar(shape):
        trailing = (shape,)
    else:
        trailing = tuple(shape)
    return (length, *trailing)


def mlp(sizes, activation, output_activation=nn.Identity):
    """Build a fully connected ``nn.Sequential`` from a list of layer widths.

    Each consecutive pair ``(sizes[i], sizes[i + 1])`` becomes an ``nn.Linear``
    followed by an activation instance. Every layer except the last gets
    ``activation()``, and the last one gets ``output_activation()``.

    Args:
        sizes: layer widths, input first and output last.
        activation: activation module class used after hidden layers.
        output_activation: activation module class used after the final layer.

    Returns:
        ``nn.Sequential`` alternating Linear and activation modules. It is empty
        when ``sizes`` has fewer than two entries.
    """
    widths = list(sizes)
    final_idx = len(widths) - 2
    modules = []
    for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
        act_cls = output_activation if idx == final_idx else activation
        modules.extend((nn.Linear(fan_in, fan_out), act_cls()))
    return nn.Sequential(*modules)


def coherent_mlp(
    sizes, activation, beta=0.01, diag_std_w_init=0.017, output_activation=nn.Identity
):
    """Build an MLP whose final layer is an ``OffPolicyCoherentLinear``.

    Every layer except the last is a plain ``nn.Linear`` followed by
    ``activation()``. The last layer maps ``sizes[-2] -> sizes[-1]`` through
    ``OffPolicyCoherentLinear(sizes[-2], sizes[-1], beta, diag_std_w_init)``,
    followed by ``output_activation()``.

    Args:
        sizes: layer widths, input first and output last (at least two entries).
        activation: activation module class used after each hidden layer.
        beta: forwarded to ``OffPolicyCoherentLinear``.
        diag_std_w_init: forwarded to ``OffPolicyCoherentLinear``.
        output_activation: activation module class applied after the coherent
            layer.

    Returns:
        ``nn.Sequential``. The coherent layer is always at index ``-2``
        (immediately before the output activation), which
        ``MLPActorCritic`` relies on.
    """
    layers = []
    # Hidden layers always use ``activation``; ``output_activation`` is applied
    # only after the coherent output layer appended below.
    for fan_in, fan_out in zip(sizes[:-2], sizes[1:-1]):
        layers += [nn.Linear(fan_in, fan_out), activation()]
    layers += [
        OffPolicyCoherentLinear(sizes[-2], sizes[-1], beta, diag_std_w_init),
        output_activation(),
    ]
    return nn.Sequential(*layers)


def PSNE_mlp(sizes, activation, output_activation=nn.Identity):
    """Build an MLP in which every linear layer is a ``PSNELinear``.

    Consecutive widths ``(sizes[i], sizes[i + 1])`` become ``PSNELinear`` layers.
    Each is followed by ``activation()``, except the final one, which is
    followed by ``output_activation()``.

    Args:
        sizes: layer widths, input first and output last.
        activation: activation module class used after hidden layers.
        output_activation: activation module class used after the final layer.

    Returns:
        ``nn.Sequential`` alternating ``PSNELinear`` and activation modules.
    """
    width_pairs = list(zip(sizes[:-1], sizes[1:]))
    layers = []
    for position, (fan_in, fan_out) in enumerate(width_pairs, start=1):
        act_cls = activation if position < len(width_pairs) else output_activation
        layers += [PSNELinear(fan_in, fan_out), act_cls()]
    return nn.Sequential(*layers)


def count_vars(module):
    """Return the total number of scalar entries across ``module.parameters()``."""
    total = 0
    for param in module.parameters():
        total += np.prod(param.shape)
    return total


# Clamp range for the log standard deviation predicted by
# SquashedGaussianMLPActor in "action" exploration mode, so the Gaussian std
# stays within [exp(-20), exp(2)].
LOG_STD_MIN, LOG_STD_MAX = -20, 2


class SquashedGaussianMLPActor(nn.Module):
    """Tanh-squashed Gaussian policy (as in SAC) with optional noisy-layer variants.

    ``exploration`` selects how the Gaussian mean (and std) are produced:

    * ``"action"``: a shared MLP trunk (``self.net``) feeds two linear heads.
      ``self.mu_layer`` gives the mean and ``self.log_std_layer`` gives a
      clamped log standard deviation.
    * ``"coherent"``: ``coherent_mlp`` maps observations directly to the mean
      (stored in ``self.mu_layer``); the std is fixed at 0.1.
    * ``"PSNE"``: ``PSNE_mlp`` maps observations directly to the mean
      (stored in ``self.mu_layer``); the std is fixed at 0.1.

    NOTE(review): in the two noisy-layer modes, exploration is presumably driven
    mainly by the noise inside those layers (see
    ``spinup.exploration.nets_exploration``); confirm the intended role of the
    fixed 0.1 std.

    Args:
        obs_dim: observation size (input width of the first layer).
        act_dim: action size.
        hidden_sizes: widths of the hidden layers.
        activation: activation module class for hidden layers.
        act_limit: scale multiplied onto the ``tanh``-squashed action.
        exploration: ``"action"``, ``"coherent"`` or ``"PSNE"``.
        beta: forwarded to ``coherent_mlp`` (``"coherent"`` only).
        diag_std_w_init: forwarded to ``coherent_mlp`` (``"coherent"`` only).

    Raises:
        TypeError: if ``exploration`` is not one of the supported strings.
    """

    def __init__(
        self,
        obs_dim,
        act_dim,
        hidden_sizes,
        activation,
        act_limit,
        exploration="action",
        beta=0.01,
        diag_std_w_init=0.017,
    ):
        super().__init__()
        self.exploration = exploration
        if exploration == "action":
            # The trunk ends with an activation too (output_activation=activation),
            # so both heads see activated features.
            self.net = mlp([obs_dim] + list(hidden_sizes), activation, activation)
            self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim)
            self.log_std_layer = nn.Linear(hidden_sizes[-1], act_dim)
        elif exploration == "coherent":
            pi_sizes = [obs_dim] + list(hidden_sizes) + [act_dim]
            self.mu_layer = coherent_mlp(pi_sizes, activation, beta, diag_std_w_init)
        elif exploration == "PSNE":
            pi_sizes = [obs_dim] + list(hidden_sizes) + [act_dim]
            self.mu_layer = PSNE_mlp(pi_sizes, activation)
        else:
            raise TypeError("Legit types of exploration: action, coherent or PSNE!")
        self.act_limit = act_limit

    def forward(self, obs, deterministic=False, with_logprob=True):
        """Return a squashed action and, optionally, its log-probability.

        Args:
            obs: observation tensor; a single observation or a batch.
            deterministic: if True, use the Gaussian mean instead of sampling.
            with_logprob: if True, also compute the log-density of the action
                under the tanh-squashed distribution.

        Returns:
            ``(pi_action, logp_pi)``. ``pi_action`` is
            ``act_limit * tanh(u)``, where ``u`` is the pre-squash sample (or
            mean). ``logp_pi`` is the log-density summed over the last (action)
            axis, or ``None`` when ``with_logprob`` is False.
        """
        if self.exploration == "action":
            net_out = self.net(obs)
            mu = self.mu_layer(net_out)
            log_std = self.log_std_layer(net_out)
            # Bound log std before exponentiating to keep std numerically sane.
            log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
            std = torch.exp(log_std)
        else:
            # Noisy-layer modes only output the mean; a scalar std broadcasts
            # over ``mu`` inside ``Normal``.
            mu = self.mu_layer(obs)
            std = 0.1

        # Pre-squash distribution and sample
        pi_distribution = Normal(mu, std)
        if deterministic:
            # Only used for evaluating policy at test time.
            pi_action = mu
        else:
            # rsample() uses the reparameterization trick, so gradients flow
            # back through mu and std.
            pi_action = pi_distribution.rsample()

        if with_logprob:
            # Compute logprob from Gaussian, and then apply correction for Tanh squashing.
            # NOTE: The correction formula is a little bit magic. To get an understanding
            # of where it comes from, check out the original SAC paper (arXiv 1801.01290)
            # and look in appendix C. This is a more numerically-stable equivalent to Eq 21.
            # Try deriving it yourself as a (very difficult) exercise. :)
            logp_pi = pi_distribution.log_prob(pi_action).sum(axis=-1)
            # Reduce the correction over the same (last) axis as the Gaussian
            # term, so un-batched 1-D inputs work as well as batches.
            logp_pi -= (2 * (np.log(2) - pi_action - F.softplus(-2 * pi_action))).sum(
                axis=-1
            )
        else:
            logp_pi = None

        pi_action = torch.tanh(pi_action)
        pi_action = self.act_limit * pi_action

        return pi_action, logp_pi


class MLPQFunction(nn.Module):
    """State-action value network: concatenates (obs, act) and maps it to a scalar.

    Args:
        obs_dim: observation size.
        act_dim: action size.
        hidden_sizes: widths of the hidden layers.
        activation: activation module class for hidden layers.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        layer_widths = [obs_dim + act_dim, *hidden_sizes, 1]
        self.q = mlp(layer_widths, activation)

    def forward(self, obs, act):
        """Return Q(obs, act) with the trailing size-1 output axis removed."""
        joint_input = torch.cat([obs, act], dim=-1)
        q_values = self.q(joint_input)
        # Critical to ensure q has right shape: drop the last (size-1) axis.
        return q_values.squeeze(-1)


class MLPActorCritic(nn.Module):
    """SAC actor-critic: a squashed-Gaussian policy ``pi`` plus twin Q-networks.

    The noise-control methods (``adapt_param_noise``, ``set_pi_if_with_noise``,
    ``reset``) forward to the policy's noisy layers. They do nothing in
    ``"action"`` exploration mode.

    Args:
        observation_space: object whose ``shape[0]`` is the observation size
            (presumably a gym ``Box``; verify against caller).
        action_space: object whose ``shape[0]`` is the action size. Its
            ``high[0]`` is used as the action limit for every dimension.
        hidden_sizes: hidden widths shared by the policy and both Q-networks.
        activation: activation module class for hidden layers.
        exploration: ``"action"``, ``"coherent"`` or ``"PSNE"``.
        beta: forwarded to the policy (used for ``"coherent"``).
        diag_std_w_init: forwarded to the policy (used for ``"coherent"``).
    """

    def __init__(
        self,
        observation_space,
        action_space,
        hidden_sizes=(256, 256),
        activation=nn.ReLU,
        exploration="action",
        beta=0.01,
        diag_std_w_init=0.017,
    ):
        super().__init__()

        obs_dim = observation_space.shape[0]
        act_dim = action_space.shape[0]
        # Only the first upper bound is read.
        # NOTE(review): this assumes symmetric bounds shared by all action dims.
        act_limit = action_space.high[0]

        # build policy and value functions
        self.pi = SquashedGaussianMLPActor(
            obs_dim,
            act_dim,
            hidden_sizes,
            activation,
            act_limit,
            exploration,
            beta,
            diag_std_w_init,
        )
        self.q1 = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation)
        self.q2 = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation)

    def act(self, obs, deterministic=False):
        """Return the policy's action for ``obs`` as a numpy array, without autograd.

        Args:
            obs: observation tensor accepted by ``self.pi``.
            deterministic: if True, use the policy mean instead of sampling.
        """
        with torch.no_grad():
            a, _ = self.pi(obs, deterministic, False)
        # .cpu() is a no-op for CPU tensors and avoids a failure in .numpy()
        # when the model lives on another device.
        return a.cpu().numpy()

    def _coherent_layer(self):
        """Return the coherent noisy layer of a ``"coherent"`` policy."""
        # coherent_mlp ends with [OffPolicyCoherentLinear, output activation],
        # so the noisy layer is second from the end.
        return self.pi.mu_layer[-2]

    def _psne_layers(self):
        """Return the ``PSNELinear`` layers of a ``"PSNE"`` policy, in order."""
        return [layer for layer in self.pi.mu_layer if isinstance(layer, PSNELinear)]

    def adapt_param_noise(self, mse_if_with_noise):
        """Pass ``mse_if_with_noise`` to each noisy layer's ``adapt_std_w``.

        The adaptation rule itself is implemented in the noisy layer classes.
        """
        if self.pi.exploration == "coherent":
            self._coherent_layer().adapt_std_w(mse_if_with_noise)
        elif self.pi.exploration == "PSNE":
            for layer in self._psne_layers():
                # The meaning of the second argument (False) is defined by
                # PSNELinear.adapt_std_w.
                layer.adapt_std_w(mse_if_with_noise, False)

    def set_pi_if_with_noise(self, if_with_noise):
        """Set the ``if_with_noise`` flag on every noisy layer of the policy."""
        if self.pi.exploration == "coherent":
            self._coherent_layer().if_with_noise = if_with_noise
        elif self.pi.exploration == "PSNE":
            for layer in self._psne_layers():
                layer.if_with_noise = if_with_noise

    def reset(self):
        """Call ``reset()`` on every noisy layer of the policy."""
        if self.pi.exploration == "coherent":
            self._coherent_layer().reset()
        elif self.pi.exploration == "PSNE":
            for layer in self._psne_layers():
                layer.reset()
